import os.path as osp
import ipdb
from tqdm import tqdm
import argparse
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, DeepGraphInfomax
from torch_geometric.data import GraphSAINTRandomWalkSampler, NeighborSampler
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import subgraph
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.data import DataLoader
import scipy.sparse as ss
import numpy as np


class Encoder(nn.Module):
    """Single-layer GraphSAGE encoder followed by a per-channel PReLU.

    Args:
        in_channels: dimensionality of the input node features.
        hidden_channels: dimensionality of the produced node embeddings; also
            the number of learnable PReLU slopes (one per output channel).
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        # Attribute names and creation order are kept stable so saved
        # state_dict keys and parameter initialisation order do not change.
        self.conv = SAGEConv(in_channels, hidden_channels)
        self.prelu = nn.PReLU(hidden_channels)

    def forward(self, x, edge_index):
        """Return PReLU(SAGEConv(x, edge_index)) for every node."""
        return self.prelu(self.conv(x, edge_index))


def corruption(x, edge_index):
    """Produce the negative (corrupted) graph view used by DGI.

    Node feature rows are shuffled with a random permutation, while the
    connectivity is passed through unchanged (the very same object is
    returned). Each node therefore ends up carrying some other node's features.
    """
    num_nodes = x.size(0)
    shuffled_x = x[torch.randperm(num_nodes)]
    return shuffled_x, edge_index


def train(model, data, optimizer):
    """Perform one full-batch optimisation step of the DGI objective.

    The model is switched to training mode, the positive embeddings, negative
    embeddings and summary are computed on the whole graph, and one optimizer
    step is taken on the resulting loss.

    Returns:
        The loss tensor computed for this step.
    """
    model.train()
    optimizer.zero_grad()

    positive, negative, graph_summary = model(data.x, data.edge_index)
    step_loss = model.loss(positive, negative, graph_summary)

    step_loss.backward()
    optimizer.step()
    return step_loss

@torch.no_grad()
def test(model, data, split_idx, SAVEPATH=None, Emb_only=False):
    model.eval()
    z, _, _ = model(data.x, data.edge_index)
#     z = inference(model, data.x, subgraph_loader, device)
    if SAVEPATH is not None:
        torch.save(z,SAVEPATH)
# '/home/ec2-user/Eli/SSL_baselines/DGI_embedding/OGB_feature.pt'
    if not Emb_only:
        acc = model.test(z[split_idx['train']], data.y[split_idx['train']].view(-1),
                         z[split_idx['test']], data.y[split_idx['test']].view(-1), max_iter=500)
        print('Accuracy: {:.4f}'.format(acc))
        return acc
    else:
        print('Skip Logistic regression. Save Emb only.')
        return

def main():
    """Train a DGI model on ogbn-arxiv and save the learned node embeddings.

    Command-line options:
        --hidden_channels     embedding size (default 768)
        --save_name           embeddings are written to
                              ./DGI_embedding/input_<save_name>.pt
        --input_feature_path  optional .npy file whose array replaces the
                              dataset node features; the string 'None'
                              (default) means "keep the dataset features"
        --epochs              number of full-batch training steps (default 500)
        --device              CUDA device index (default 0); CPU is used when
                              CUDA is not available
        --lr                  Adam learning rate (default 0.0001)
    """
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--hidden_channels', type=int, default=768)
    parser.add_argument('--save_name', type=str, default='OGB_feature')
    parser.add_argument('--input_feature_path', type=str, default='None')
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--lr', type=float, default=0.0001)
    args = parser.parse_args()

    SAVEPATH = './DGI_embedding/input_{}.pt'.format(args.save_name)
    # Create the output directory before training so a long run is not lost
    # to a missing folder at save time.
    os.makedirs(osp.dirname(SAVEPATH), exist_ok=True)

    dataset = PygNodePropPredDataset(name="ogbn-arxiv", root="../../dataset")
    split_idx = dataset.get_idx_split()
    data = dataset[0]

    # Replace node features here
    if args.input_feature_path != 'None':
        # Arrays saved with NumPy are frequently float64, while the encoder's
        # parameters use PyTorch's default float32; cast explicitly to avoid a
        # dtype mismatch inside SAGEConv.
        data.x = torch.as_tensor(np.load(args.input_feature_path), dtype=torch.float32)
        print("Pretrained node features loaded! Path: {}".format(args.input_feature_path))

    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(args.device))
    else:
        print('CUDA not available; running on CPU.')
        device = torch.device('cpu')
    data = data.to(device)

    model = DeepGraphInfomax(
        hidden_channels=args.hidden_channels, encoder=Encoder(data.x.size(1), args.hidden_channels),
        summary=lambda z, x, edge_index: torch.sigmoid(z.mean(dim=0)),
        corruption=corruption).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(1, args.epochs + 1):
        loss = train(model, data, optimizer)
        print('Epoch: {:03d}, Loss: {:.4f}'.format(epoch, loss))

    # Emb_only=True: just dump embeddings, no logistic-regression evaluation.
    test(model, data, split_idx, SAVEPATH, Emb_only=True)

# Script entry point: parse CLI arguments, train DGI and save the embeddings.
if __name__ == "__main__":
    main()
